{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Traversing Expression Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pymbolic.primitives as p\n",
    "x = p.Variable(\"x\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Power(Sum((Variable('x'), 3)), 5)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "u = (x+3)**5\n",
    "u"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Traversal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
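    "There are many ways to walk this expression:\n",
    "\n",
    "* One big recursive function with many `if isinstance` checks\n",
    "* The \"visitor pattern\": define a class that dispatches to a different method for each node type\n",
    "\n",
    "pymbolic's mappers follow the visitor pattern. Each node class names the mapper method that handles it in its `mapper_method` attribute:"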
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'map_sum'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p.Sum.mapper_method"
   ]
  },
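  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the dispatch idea concrete, here is a toy sketch in plain Python that runs without pymbolic. The classes `Var`, `Add`, `Pow`, `ToyMapper` and `NodeCounter` are hypothetical stand-ins, not pymbolic's. Each node names its handler in a `mapper_method` attribute, mimicking the one shown above, and the mapper looks that method up with `getattr`. For contrast, `count_nodes_isinstance` takes the first approach: a single function with `isinstance` checks. The visitor version keeps each node type's logic in its own method, which lets a subclass override one case and leave the rest alone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "# Hypothetical toy node classes (not pymbolic's). Each one names the\n",
    "# method that should handle it, like pymbolic's mapper_method.\n",
    "@dataclass(frozen=True)\n",
    "class Var:\n",
    "    name: str\n",
    "    mapper_method = \"map_variable\"  # class attribute, not a dataclass field\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class Add:\n",
    "    children: tuple\n",
    "    mapper_method = \"map_sum\"\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class Pow:\n",
    "    base: object\n",
    "    exponent: object\n",
    "    mapper_method = \"map_power\"\n",
    "\n",
    "# Approach 1: one recursive function with isinstance checks\n",
    "def count_nodes_isinstance(expr):\n",
    "    if isinstance(expr, Add):\n",
    "        return 1 + sum(count_nodes_isinstance(ch) for ch in expr.children)\n",
    "    elif isinstance(expr, Pow):\n",
    "        return (1 + count_nodes_isinstance(expr.base)\n",
    "                + count_nodes_isinstance(expr.exponent))\n",
    "    else:  # a Var or a plain number\n",
    "        return 1\n",
    "\n",
    "# Approach 2: visitor-style dispatch on the node's mapper_method name\n",
    "class ToyMapper:\n",
    "    def __call__(self, expr):\n",
    "        # Plain numbers have no mapper_method, so treat them as constants\n",
    "        method_name = getattr(expr, \"mapper_method\", \"map_constant\")\n",
    "        return getattr(self, method_name)(expr)\n",
    "\n",
    "    rec = __call__\n",
    "\n",
    "class NodeCounter(ToyMapper):\n",
    "    def map_sum(self, expr):\n",
    "        return 1 + sum(self.rec(ch) for ch in expr.children)\n",
    "\n",
    "    def map_power(self, expr):\n",
    "        return 1 + self.rec(expr.base) + self.rec(expr.exponent)\n",
    "\n",
    "    def map_variable(self, expr):\n",
    "        return 1\n",
    "\n",
    "    def map_constant(self, expr):\n",
    "        return 1\n",
    "\n",
    "# (x + 3)**5 as a toy tree: both approaches count 5 nodes\n",
    "u_toy = Pow(Add((Var(\"x\"), 3)), 5)\n",
    "print(count_nodes_isinstance(u_toy), NodeCounter()(u_toy))"
   ]
  },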
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from pymbolic.mapper import WalkMapper\n",
    "\n",
    "class MyMapper(WalkMapper):\n",
    "    def map_sum(self, expr):\n",
    "        print(\"sum\", expr.children)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Power(Sum((Variable('x'), 3)), 5)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "u = (x+3)**5\n",
    "u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum (Variable('x'), 3)\n"
     ]
    }
   ],
   "source": [
    "mymapper = MyMapper()\n",
    "mymapper(u)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recursive Traversal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What if the expression contains one sum nested inside another?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sum((Power(Sum((Variable('x'), 3)), 5), 5))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "u = (x+3)**5 + 5\n",
    "u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum (Power(Sum((Variable('x'), 3)), 5), 5)\n"
     ]
    }
   ],
   "source": [
    "mymapper(u)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What do you notice? Is something missing?\n",
    "\n",
    "Improve the implementation as `MyMapper2`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#clear\n",
    "from pymbolic.mapper import WalkMapper\n",
    "\n",
    "class MyMapper2(WalkMapper):\n",
    "    def map_sum(self, expr):\n",
    "        print(\"sum\", expr.children)\n",
    "\n",
    "        # Recurse into the children so that nested sums are visited too\n",
    "        for ch in expr.children:\n",
    "            self.rec(ch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum (Power(Sum((Variable('x'), 3)), 5), 5)\n",
      "sum (Variable('x'), 3)\n"
     ]
    }
   ],
   "source": [
    "mymapper2 = MyMapper2()\n",
    "mymapper2(u)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mapper Inheritance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Above: `MyMapper2` only defines `map_sum`. What handles variables (`map_variable`) and powers (`map_power`)?\n",
    "* Mappers inherit all non-overridden behavior from their superclasses (here, `WalkMapper`).\n",
    "\n",
    "This makes it easy to *inherit a base behavior* and then selectively change a few pieces."
   ]
  },
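  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of this, using the same `WalkMapper` base as above. The hypothetical `VariableCollector` overrides only `map_variable`. Sums, products and powers are still walked by the methods it inherits from `WalkMapper`, so it finds variables nested anywhere in the tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pymbolic.primitives as p\n",
    "from pymbolic.mapper import WalkMapper\n",
    "\n",
    "class VariableCollector(WalkMapper):\n",
    "    \"\"\"Collect variable names; all other node types use WalkMapper's behavior.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.names = []\n",
    "\n",
    "    def map_variable(self, expr):\n",
    "        self.names.append(expr.name)\n",
    "\n",
    "x = p.Variable(\"x\")\n",
    "y = p.Variable(\"y\")\n",
    "\n",
    "# The product, power and sum nodes are traversed by inherited WalkMapper methods\n",
    "collector = VariableCollector()\n",
    "collector((x + 3)**5 * y)\n",
    "print(collector.names)"
   ]
  },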
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mappers with Values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Mappers do more than just *traverse*\n",
    "* They can also return a value\n",
    "    * What type? Any desired one.\n",
    "\n",
    "For example, a mapper could return a string:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from pymbolic.mapper import RecursiveMapper\n",
    "\n",
    "class MyStringifier(RecursiveMapper):\n",
    "    def map_sum(self, expr):\n",
    "        return \"+\".join(self.rec(ch) for ch in expr.children)\n",
    "\n",
    "    def map_product(self, expr):\n",
    "        return \"*\".join(self.rec(ch) for ch in expr.children)\n",
    "\n",
    "    def map_variable(self, expr):\n",
    "        return expr.name\n",
    "\n",
    "    def map_constant(self, expr):\n",
    "        return str(expr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'x*5+x*7'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#clear\n",
    "u = (x * 5) + (x * 7)\n",
    "\n",
    "mystrifier = MyStringifier()\n",
    "mystrifier(u)"
   ]
  },
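  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MyStringifier` is deliberately minimal. It has no `map_power`, and it never inserts parentheses, so `(x+1)*2` would come out as `x+1*2`. Below is a sketch of one way to address both, assuming a hand-picked precedence table (`ParenStringifier` and the `PREC_*` constants are hypothetical names). Each method takes the precedence of the enclosing operator as an extra argument, which `self.rec` passes through. It adds parentheses around its own result only when that result binds more loosely than the enclosing operator. For real use, pymbolic's own `StringifyMapper` in `pymbolic.mapper.stringifier` handles this more thoroughly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pymbolic.primitives as p\n",
    "from pymbolic.mapper import RecursiveMapper\n",
    "\n",
    "# Hypothetical precedence levels: a higher number binds more tightly\n",
    "PREC_NONE, PREC_SUM, PREC_PRODUCT, PREC_POWER = 0, 1, 2, 3\n",
    "\n",
    "class ParenStringifier(RecursiveMapper):\n",
    "    # Each method receives the precedence of the enclosing operator as an\n",
    "    # extra argument; self.rec passes it on to the next map_* method.\n",
    "\n",
    "    def parenthesize(self, text, prec, enclosing_prec):\n",
    "        # Wrap in parentheses only if this node binds more loosely than its parent\n",
    "        return \"(\" + text + \")\" if prec < enclosing_prec else text\n",
    "\n",
    "    def map_sum(self, expr, enclosing_prec=PREC_NONE):\n",
    "        text = \"+\".join(self.rec(ch, PREC_SUM) for ch in expr.children)\n",
    "        return self.parenthesize(text, PREC_SUM, enclosing_prec)\n",
    "\n",
    "    def map_product(self, expr, enclosing_prec=PREC_NONE):\n",
    "        text = \"*\".join(self.rec(ch, PREC_PRODUCT) for ch in expr.children)\n",
    "        return self.parenthesize(text, PREC_PRODUCT, enclosing_prec)\n",
    "\n",
    "    def map_power(self, expr, enclosing_prec=PREC_NONE):\n",
    "        # Conservatively parenthesize any base or exponent that is itself an operator\n",
    "        text = (self.rec(expr.base, PREC_POWER + 1)\n",
    "                + \"**\" + self.rec(expr.exponent, PREC_POWER + 1))\n",
    "        return self.parenthesize(text, PREC_POWER, enclosing_prec)\n",
    "\n",
    "    def map_variable(self, expr, enclosing_prec=PREC_NONE):\n",
    "        return expr.name\n",
    "\n",
    "    def map_constant(self, expr, enclosing_prec=PREC_NONE):\n",
    "        return str(expr)\n",
    "\n",
    "x = p.Variable(\"x\")\n",
    "pstr = ParenStringifier()\n",
    "print(pstr((x + 1)*2))\n",
    "print(pstr((x + 3)**5))"
   ]
  },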
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
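    "Mappers can also return another expression. `IdentityMapper` is a base class that rebuilds an expression from its parts and returns an equal expression. (In the run below, the result is a new object. Newer pymbolic versions may return unchanged subexpressions as-is, in which case `u2 is u` can print `True`.)"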
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "#clear\n",
    "from pymbolic.mapper import IdentityMapper\n",
    "\n",
    "idmap = IdentityMapper()\n",
    "u2 = idmap(u)\n",
    "print(u2 == u)\n",
    "print(u2 is u)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Term Rewriting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`IdentityMapper` can be used as a convenient base for term rewriting.\n",
    "\n",
    "As a very simple example, let us:\n",
    "\n",
    "* Change the name of all variables by appending a prime\n",
    "* Change all products to sums"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#clear\n",
    "class MyIdentityMapper(IdentityMapper):\n",
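    "    def map_variable(self, expr):\n",
    "        # Build a new variable whose name has a prime appended\n",
    "        return p.Variable(expr.name + \"'\")\n",
    "\n",
    "    def map_product(self, expr):\n",
    "        # Turn the product into a sum of the (recursively rewritten) factors\n",
    "        return p.Sum(tuple(self.rec(ch) for ch in expr.children))\n"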
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x' + 3 + (x' + 17)**3\n"
     ]
    }
   ],
   "source": [
    "u = (x*3)*(x+17)**3\n",
    "\n",
    "myidmap = MyIdentityMapper()\n",
    "print(myidmap(u))"
   ]
  },
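  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same pattern also gives a simple substitution. Below is a sketch, with `Substituter` as a hypothetical name. It replaces variables with expressions looked up by name in a dictionary and leaves every other node to `IdentityMapper`. (pymbolic also ships its own substitution machinery in `pymbolic.mapper.substitutor`.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pymbolic.primitives as p\n",
    "from pymbolic.mapper import IdentityMapper\n",
    "\n",
    "class Substituter(IdentityMapper):\n",
    "    \"\"\"Replace variables listed in `subst`; rebuild everything else.\"\"\"\n",
    "\n",
    "    def __init__(self, subst):\n",
    "        super().__init__()\n",
    "        self.subst = subst  # maps variable names to replacement expressions\n",
    "\n",
    "    def map_variable(self, expr):\n",
    "        # Keep the variable unchanged if no replacement is given for it\n",
    "        return self.subst.get(expr.name, expr)\n",
    "\n",
    "x = p.Variable(\"x\")\n",
    "y = p.Variable(\"y\")\n",
    "\n",
    "u = (x + 1)**2\n",
    "print(Substituter({\"x\": 2*y})(u))"
   ]
  },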
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.0+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
